{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMakerCore Overview of Resource Level Abstractions - XGBoost Training Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Introductions\n",
    "SageMakerCore is a Python SDK designed as a lightweight layer over boto3, the AWS SDK for Python. It is built on the concept of resource level abstractions, where SageMaker Resources are represented as Python classes. This approach enables SageMakerCore to simplify the management of SageMaker Resources and provide a more object-oriented programming interface.\n",
    "\n",
    "### Resource Level Abstraction\n",
    "Resource Level Abstractions can be best understood by examining how the AWS TrainingJob APIs are transfromed into a TrainingJob Python class abstraction in SageMakerCore.\n",
    "\n",
    "For instance, an AWS TrainingJob has the following APIs:\n",
    "1. CreateTrainingJob\n",
    "2. DescribeTrainingJob\n",
    "3. UpdateTrainingJob\n",
    "4. StopTrainingJob\n",
    "5. ListTrainingJobs\n",
    "\n",
    "In SageMakerCore, these APIs are encapsulated within a TrainingJob class that exposes these operations as methods and attributes. The details of the TrainingJob class are below:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "class TrainingJob(Base):\n",
    "    # Class attributes are mapped to describe_training_job response\n",
    "    training_job_name: str\n",
    "    training_job_arn: Optional[str] = Unassigned()\n",
    "    tuning_job_arn: Optional[str] = Unassigned()\n",
    "    labeling_job_arn: Optional[str] = Unassigned()\n",
    "    auto_ml_job_arn: Optional[str] = Unassigned()\n",
    "    model_artifacts: Optional[ModelArtifacts] = Unassigned()\n",
    "    training_job_status: Optional[str] = Unassigned()\n",
    "    ...\n",
    "\n",
    "    @classmethod\n",
    "    def create():       # Calls `create_training_job`\n",
    "\n",
    "    @classmethod\n",
    "    def get():          # Calls `describe_training_job`\n",
    "\n",
    "    @classmethod\n",
    "    def get_all():      # Calls `list_training_job`\n",
    "\n",
    "    \n",
    "    def update():       # Calls `update_training_job`\n",
    "\n",
    "\n",
    "    def stop():         # Calls `stop_training_job`\n",
    "\n",
    "\n",
    "    def refresh():      # Calls `describe_training_job` and refreshes instance attributes\n",
    "\n",
    "\n",
    "    def wait():         # Calls `describe_training_job` and waits for TrainingJob to enter terminal state\n",
    "```"
   ]
  },
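  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the mapping above concrete, here is a deliberately simplified sketch of the pattern. It is not SageMakerCore's actual implementation: `MiniTrainingJob` and `FakeSageMakerClient` are hypothetical names, and the fake client keeps jobs in memory so that the example runs without AWS credentials.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "class FakeSageMakerClient:\n",
    "    # Hypothetical stub standing in for boto3.client('sagemaker'); stores job statuses in memory.\n",
    "    def __init__(self):\n",
    "        self._jobs = {}\n",
    "\n",
    "    def create_training_job(self, TrainingJobName, **kwargs):\n",
    "        self._jobs[TrainingJobName] = 'InProgress'\n",
    "        return {'TrainingJobArn': 'arn:placeholder:' + TrainingJobName}\n",
    "\n",
    "    def describe_training_job(self, TrainingJobName):\n",
    "        return {'TrainingJobName': TrainingJobName, 'TrainingJobStatus': self._jobs[TrainingJobName]}\n",
    "\n",
    "    def stop_training_job(self, TrainingJobName):\n",
    "        # Simplification: the stub jumps straight to 'Stopped'.\n",
    "        self._jobs[TrainingJobName] = 'Stopped'\n",
    "        return {}\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class MiniTrainingJob:\n",
    "    # Attributes mirror a subset of the describe_training_job response.\n",
    "    client: FakeSageMakerClient\n",
    "    training_job_name: str\n",
    "    training_job_status: Optional[str] = None\n",
    "\n",
    "    @classmethod\n",
    "    def create(cls, client, training_job_name, **kwargs):\n",
    "        # Calls create_training_job, then describes the job to build the object.\n",
    "        client.create_training_job(TrainingJobName=training_job_name, **kwargs)\n",
    "        return cls.get(client, training_job_name)\n",
    "\n",
    "    @classmethod\n",
    "    def get(cls, client, training_job_name):\n",
    "        response = client.describe_training_job(TrainingJobName=training_job_name)\n",
    "        return cls(client, response['TrainingJobName'], response['TrainingJobStatus'])\n",
    "\n",
    "    def refresh(self):\n",
    "        # Calls describe_training_job again and updates the instance attributes.\n",
    "        response = self.client.describe_training_job(TrainingJobName=self.training_job_name)\n",
    "        self.training_job_status = response['TrainingJobStatus']\n",
    "        return self\n",
    "\n",
    "    def stop(self):\n",
    "        self.client.stop_training_job(TrainingJobName=self.training_job_name)\n",
    "\n",
    "\n",
    "job = MiniTrainingJob.create(FakeSageMakerClient(), 'demo-job')\n",
    "print(job.training_job_status)\n",
    "job.stop()\n",
    "print(job.refresh().training_job_status)\n",
    "```\n",
    "\n",
    "The point of the sketch is the shape of the interface: callers work with attributes and methods on an object rather than with request and response dictionaries."
   ]
  },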
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparing Boto3 and SageMakerCore SDKs\n",
    "\n",
    "In this notebook, we create an AWS TrainingJob to train an XGBoost Container. We will be using both Boto3 and the SageMakerCore SDKs with the goal of highlighting and comparing the differences in user experience for performing operations such as creating, updating, waiting, and listing AWS TrainingJobs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Latest SageMakerCore\n",
    "All SageMakerCore beta distributions will be released to a private s3 bucket. After being allowlisted, run the cells below to install the latest version of SageMakerCore from `s3://sagemaker-core-beta-artifacts/sagemaker_core-latest.tar.gz`\n",
    "\n",
    "Ensure you are using a kernel with python version >=3.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uninstall previous version of sagemaker-core and restart kernel\n",
    "!pip uninstall sagemaker-core -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the latest version of sagemaker-core\n",
    "\n",
    "!pip install sagemaker-core --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the version of sagemaker-core\n",
    "!pip show -v sagemaker-core"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Additional Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install additional packages\n",
    "\n",
    "!pip install -U scikit-learn pandas boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup\n",
    "\n",
    "Let's start by specifying:\n",
    "- AWS region.\n",
    "- The IAM role arn used to give learning and hosting access to your data. Ensure your enviornment has AWS Credentials configured.\n",
    "- The S3 bucket that you want to use for storing training and model data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.helper.session_helper import Session, get_execution_role\n",
    "from rich import print\n",
    "\n",
    "# Get region, role, bucket\n",
    "\n",
    "sagemaker_session = Session()\n",
    "region = sagemaker_session.boto_region_name\n",
    "role = get_execution_role()\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "print(role)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Prepare Dataset\n",
    "For this example, we will be using the IRIS data set from `sklearn.datasets` to train our XGBoost container."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Get IRIS Data\n",
    "\n",
    "iris = load_iris()\n",
    "iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "iris_df['target'] = iris.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Prepare Data\n",
    "\n",
    "os.makedirs('./data', exist_ok=True)\n",
    "\n",
    "iris_df = iris_df[['target'] + [col for col in iris_df.columns if col != 'target']]\n",
    "\n",
    "train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42)\n",
    "\n",
    "train_data.to_csv('./data/train.csv', index=False, header=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload Data to S3\n",
    "In this step, we will upload the train and test data to the S3 bucket configured earlier using `sagemaker_session.default_bucket()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload Data\n",
    "\n",
    "prefix = \"DEMO-scikit-iris\"\n",
    "TRAIN_DATA = \"train.csv\"\n",
    "DATA_DIRECTORY = \"data\"\n",
    "\n",
    "train_input = sagemaker_session.upload_data(\n",
    "    DATA_DIRECTORY, bucket=bucket, key_prefix=\"{}/{}\".format(prefix, DATA_DIRECTORY)\n",
    ")\n",
    "\n",
    "s3_input_path = \"s3://{}/{}/data/{}\".format(bucket, prefix, TRAIN_DATA)\n",
    "s3_output_path = \"s3://{}/{}/output\".format(bucket, prefix)\n",
    "\n",
    "print(s3_input_path)\n",
    "print(s3_output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fetch the XGBoost Image URI\n",
    "In this step, we will fetch the XGBoost Image URI we will use as an input parameter when creating an AWS TrainingJob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image name is hardcoded here\n",
    "# Image name can be programatically got by using sagemaker package and calling image_uris.retrieve\n",
    "# Since that is a high level abstraction that has multiple dependencies, the image URIs functionalities will live in sagemaker (V2)\n",
    "\n",
    "image = \"433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create TrainingJob with Boto3\n",
    "With the necessary setup completed, we can now create an AWS TrainingJob. First we will begin by creating a TrainingJob with Boto3 to understand what the experience is like when interecting directly with low-level APIs through Boto3.\n",
    "\n",
    "When executing the following cells there are a few things to note about the experience with Boto3:\n",
    "1. Boto3 dynamically generates the API operation methods like `create_training_job`. When a client is instantiated, the methods are generated from the JSON service model description and are not statically coded into the boto3 library.\n",
    "2. Boto3 returns a JSON response. As a result, users must either be familiar with the structure of these responses or refer to the documentation to parse them correctly.\n",
    "3. Boto3 client methods expect keyword arguments. Similar to the experience with JSON response, users must be familiar with what keyword argumnets are expected or refer to the documentation to pass them correctly.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create TrainingJob with Boto3\n",
    "\n",
    "import time\n",
    "import boto3\n",
    "\n",
    "client = boto3.client('sagemaker')\n",
    "job_name_boto = 'xgboost-iris-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n",
    "\n",
    "response = client.create_training_job(\n",
    "    TrainingJobName=job_name_boto,\n",
    "    HyperParameters={\n",
    "        'objective': 'multi:softmax',\n",
    "        'num_class': '3',\n",
    "        'num_round': '10',\n",
    "        'eval_metric': 'merror'\n",
    "    },\n",
    "    AlgorithmSpecification={\n",
    "        'TrainingImage': image,\n",
    "        'TrainingInputMode': 'File'\n",
    "    },\n",
    "    RoleArn=role,\n",
    "    InputDataConfig=[\n",
    "        {\n",
    "            'ChannelName': 'train',\n",
    "            'ContentType': 'csv',\n",
    "            'DataSource': {\n",
    "                'S3DataSource': {\n",
    "                    'S3DataType': 'S3Prefix',\n",
    "                    'S3Uri': s3_input_path,\n",
    "                    'S3DataDistributionType': 'FullyReplicated'\n",
    "                }\n",
    "            },\n",
    "            'CompressionType': 'None',\n",
    "            'RecordWrapperType': 'None'\n",
    "        }\n",
    "    ],\n",
    "    OutputDataConfig={\n",
    "        'S3OutputPath': s3_output_path\n",
    "    },\n",
    "    ResourceConfig={\n",
    "        'InstanceType': 'ml.m4.xlarge',\n",
    "        'InstanceCount': 1,\n",
    "        'VolumeSizeInGB': 30\n",
    "    },\n",
    "    StoppingCondition={\n",
    "        'MaxRuntimeInSeconds': 600\n",
    "    }\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wait for TrainingJob with Boto3\n",
    "When a user creates a TrainingJob it is often the case that they would wish to wait on the TrainingJob to complete. Below is an example of how a user wait on a TrainingJob using Boto3. Notebly, this requires creating some logic to poll the TrainingJob using `describe_training_job` until the `TrainingJobStatus` is `'Failed'`, `'Completed'`, or `'Stopped'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for TrainingJob with Boto3\n",
    "import time\n",
    "\n",
    "while True:\n",
    "    response = client.describe_training_job(TrainingJobName=job_name_boto)\n",
    "    status = response['TrainingJobStatus']\n",
    "    if status in ['Failed', 'Completed', 'Stopped']:\n",
    "        if status == 'Failed':\n",
    "            print(response['FailureReason'])\n",
    "        break\n",
    "    print(\"-\", end=\"\")\n",
    "    time.sleep(5)"
   ]
  },
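  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The polling loop above runs until the job reaches a terminal state, with no upper bound on the wait. As a minimal sketch (not part of Boto3 or SageMakerCore), the same logic could be wrapped in a reusable helper with a timeout. The names `wait_for_training_job`, `poll_seconds`, `timeout_seconds`, and the stub `fake_describe` are hypothetical; with real credentials, `client.describe_training_job` could be passed as `describe` instead.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "TERMINAL_STATUSES = {'Completed', 'Failed', 'Stopped'}\n",
    "\n",
    "\n",
    "def wait_for_training_job(describe, job_name, poll_seconds=5, timeout_seconds=3600):\n",
    "    # Poll `describe` until the job is in a terminal status or the timeout expires.\n",
    "    deadline = time.monotonic() + timeout_seconds\n",
    "    while True:\n",
    "        response = describe(TrainingJobName=job_name)\n",
    "        status = response['TrainingJobStatus']\n",
    "        if status in TERMINAL_STATUSES:\n",
    "            return response\n",
    "        if time.monotonic() >= deadline:\n",
    "            raise TimeoutError(f'{job_name} is still {status} after {timeout_seconds}s')\n",
    "        time.sleep(poll_seconds)\n",
    "\n",
    "\n",
    "# Hypothetical stub standing in for client.describe_training_job so the sketch runs offline.\n",
    "_fake_statuses = iter(['InProgress', 'InProgress', 'Completed'])\n",
    "\n",
    "\n",
    "def fake_describe(TrainingJobName):\n",
    "    return {'TrainingJobName': TrainingJobName, 'TrainingJobStatus': next(_fake_statuses)}\n",
    "\n",
    "\n",
    "result = wait_for_training_job(fake_describe, 'demo-job', poll_seconds=0)\n",
    "print(result['TrainingJobStatus'])\n",
    "```"
   ]
  },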
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create TrainingJob with SageMakerCore\n",
    "In this step we will use SageMakerCore to create a TrainingJob to understand what experience the object-oriented resource level abstractions provide for users.\n",
    "\n",
    "When executing the following cells, there are a few things to note about the experience with SageMakerCore:\n",
    "1. SageMakerCore generates Python classes and methods from the service model JSON, similar to Boto3. However, this generation is done prior to a release, resulting in a statically coded interface in the library.\n",
    "2. SageMakerCore adopts an object-oriented approach, providing users with clear visibility of available methods and attributes through type hinting and IDE IntelliSense\n",
    "3. Instead of returning JSON responses like Boto3, SageMakerCore returns objects. This allows users to access response attributes directly from the returned object, eliminating the need to parse JSON or refer to the documentation for structure details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create TrainingJob with SageMakerCore\n",
    "\n",
    "import time\n",
    "from sagemaker.core.resources import TrainingJob, AlgorithmSpecification, Channel, DataSource, S3DataSource, \\\n",
    "    OutputDataConfig, ResourceConfig, StoppingCondition\n",
    "\n",
    "job_name_v3 = 'xgboost-iris-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n",
    "\n",
    "training_job = TrainingJob.create(\n",
    "    training_job_name=job_name_v3,\n",
    "    hyper_parameters={\n",
    "        'objective': 'multi:softmax',\n",
    "        'num_class': '3',\n",
    "        'num_round': '10',\n",
    "        'eval_metric': 'merror'\n",
    "    },\n",
    "    algorithm_specification=AlgorithmSpecification(\n",
    "        training_image=image,\n",
    "        training_input_mode='File'\n",
    "    ),\n",
    "    role_arn=role,\n",
    "    input_data_config=[\n",
    "        Channel(\n",
    "            channel_name='train',\n",
    "            content_type='csv',\n",
    "            compression_type='None',\n",
    "            record_wrapper_type='None',\n",
    "            data_source=DataSource(\n",
    "                s3_data_source=S3DataSource(\n",
    "                    s3_data_type='S3Prefix',\n",
    "                    s3_uri=s3_input_path,\n",
    "                    s3_data_distribution_type='FullyReplicated'\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "    ],\n",
    "    output_data_config=OutputDataConfig(\n",
    "        s3_output_path=s3_output_path\n",
    "    ),\n",
    "    resource_config=ResourceConfig(\n",
    "        instance_type='ml.m4.xlarge',\n",
    "        instance_count=1,\n",
    "        volume_size_in_gb=30\n",
    "    ),\n",
    "    stopping_condition=StoppingCondition(\n",
    "        max_runtime_in_seconds=600\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wait for TrainingJob with SageMakerCore\n",
    "In SageMakerCore, the logic required to wait on a resource is abstracted away using a `wait()` method. As a result, a user can directly call the `wait()` method on a TrainingJob object instance like below. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for TrainingJob with SageMakerCore\n",
    "\n",
    "training_job.wait(logs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## List TrainingJobs with Boto3\n",
    "When a user lists TrainingJobs, there are 2 main approaches provided by Boto3. \n",
    "\n",
    "1. The first is calling `list_training_jobs` directly and implementing some logic to handle the NextToken provided in the response to enable pagination.\n",
    "2. The second is by utilizing the Boto3 `get_paginator` method to get a paginator that encapsulates the NextToken and simplifies the logic required.\n",
    "\n",
    "Both approaches are shown below. Although the boto3 provided paginator simplifies the logic over using a NextToken, in both cases the user must understand the structure of the list responses or refer to the docs (ie, understand to access TrainingJobSummaries by doing `response[\"TrainingJobSummaries\"]`)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List TrainingJobs with Boto3\n",
    "import datetime\n",
    "import boto3\n",
    "\n",
    "client = boto3.client('sagemaker')\n",
    "\n",
    "creation_time_after = datetime.datetime.now() - datetime.timedelta(days=1)\n",
    "\n",
    "# List TrainingJobs with NextToken\n",
    "next_token = None\n",
    "while True:\n",
    "    if next_token:\n",
    "        response = client.list_training_jobs(CreationTimeAfter=creation_time_after, NextToken=next_token)\n",
    "    else: \n",
    "        response = client.list_training_jobs(CreationTimeAfter=creation_time_after)\n",
    "    \n",
    "    for job in response['TrainingJobSummaries']:\n",
    "        print(job['TrainingJobName'], job[\"TrainingJobStatus\"])\n",
    "        \n",
    "    next_token = response.get('NextToken')\n",
    "    \n",
    "    if not next_token:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import boto3\n",
    "\n",
    "client = boto3.client('sagemaker')\n",
    "creation_time_after = datetime.datetime.now() - datetime.timedelta(days=1)\n",
    "\n",
    "# List TrainingJobs with Boto3 Paginator\n",
    "paginator = client.get_paginator('list_training_jobs')\n",
    "for response in paginator.paginate(CreationTimeAfter=creation_time_after):\n",
    "    for job in response['TrainingJobSummaries']:\n",
    "        print(job['TrainingJobName'], job[\"TrainingJobStatus\"])"
   ]
  },
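  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both listing approaches above hand back pages of dictionaries. As a rough sketch (an illustration only, not how Boto3 or SageMakerCore implement it), the NextToken loop can be wrapped in a generator that yields one summary at a time and requests the next page only once the current one is exhausted. The names `iter_training_job_summaries` and `fake_list_training_jobs` are hypothetical; the stub returns canned pages so the example runs without AWS credentials.\n",
    "\n",
    "```python\n",
    "def iter_training_job_summaries(list_page, **filters):\n",
    "    # Yield summaries lazily, following NextToken until no further page is returned.\n",
    "    next_token = None\n",
    "    while True:\n",
    "        kwargs = dict(filters)\n",
    "        if next_token:\n",
    "            kwargs['NextToken'] = next_token\n",
    "        response = list_page(**kwargs)\n",
    "        yield from response['TrainingJobSummaries']\n",
    "        next_token = response.get('NextToken')\n",
    "        if not next_token:\n",
    "            break\n",
    "\n",
    "\n",
    "# Hypothetical stub standing in for client.list_training_jobs: two pages of canned data.\n",
    "_pages = {\n",
    "    None: {\n",
    "        'TrainingJobSummaries': [{'TrainingJobName': 'job-a'}, {'TrainingJobName': 'job-b'}],\n",
    "        'NextToken': 'page-2',\n",
    "    },\n",
    "    'page-2': {'TrainingJobSummaries': [{'TrainingJobName': 'job-c'}]},\n",
    "}\n",
    "\n",
    "\n",
    "def fake_list_training_jobs(NextToken=None):\n",
    "    return _pages[NextToken]\n",
    "\n",
    "\n",
    "names = [job['TrainingJobName'] for job in iter_training_job_summaries(fake_list_training_jobs)]\n",
    "print(names)\n",
    "```"
   ]
  },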
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## List TrainingJobs with SageMakerCore\n",
    "\n",
    "In SageMakerCore, listing is done similar to the boto3 paginator approach but instead with a `ResourceIterator` which implements the python iterator protocol to instantiate and return resource objects only as they are accessed.\n",
    "\n",
    "\n",
    "Below, is an example of how the `get_all()` method would be used to list TrainingJobs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List TrainingJobs with SageMakerCore\n",
    "import datetime\n",
    "from sagemaker.core.resources import TrainingJob\n",
    "\n",
    "creation_time_after = datetime.datetime.now() - datetime.timedelta(days=1)\n",
    "\n",
    "resource_iterator = TrainingJob.get_all(creation_time_after=creation_time_after)\n",
    "for job in resource_iterator:\n",
    "    print(job.training_job_name, job.training_job_status)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Delete All SageMaker Resources\n",
    "The following code block will call the delete() method for any SageMaker Core Resources created during the execution of this notebook which were assigned to local or global variables. If you created any additional deleteable resources without assigning the returning object to a unique variable, you will need to delete the resource manually by doing something like:\n",
    "\n",
    "```python\n",
    "resource = Resource.get(\"resource-name\")\n",
    "resource.delete()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete any sagemaker core resource objects created in this notebook\n",
    "def delete_all_sagemaker_resources():\n",
    "    all_objects = list(locals().values()) + list(globals().values())\n",
    "    deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker.core.resources']\n",
    "    \n",
    "    for obj in deletable_objects:\n",
    "        obj.delete()\n",
    "        \n",
    "delete_all_sagemaker_resources()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
